{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/lta/anaconda3/envs/llm_universe_2.x/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings\n",
    "\n",
    "embedding = HuggingFaceEmbeddings(model_name='BAAI/bge-small-zh-v1.5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.document_loaders.pdf import PyMuPDFLoader\n",
    "\n",
    "# 创建一个 PyMuPDFLoader Class 实例，输入为待加载的 pdf 文档路径\n",
    "loader = PyMuPDFLoader(\"../../../data_base/knowledge_db/pumkin_book/pumpkin_book.pdf\")\n",
    "\n",
    "# 调用 PyMuPDFLoader Class 的函数 load 对 pdf 文件进行加载\n",
    "pdf_pages = loader.load()\n",
    "# 第13页为南瓜书第一页正文，因此从13页开始,从倒数13页涉及敏感用语，因此从-13页结束\n",
    "data_pages = pdf_pages[13:-13]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def clean_text(text: str):\n",
    "    # 删除每页开头与结尾标语及链接\n",
    "    text = re.sub(r'→_→\\n欢迎去各大电商平台选购纸质版南瓜书《机器学习公式详解》\\n←_←', '', text)\n",
    "    text = re.sub(r'→_→\\n配套视频教程：https://www.bilibili.com/video/BV1Mh411e7VU\\n←_←', '', text)\n",
    "    # 删除字符串开头的空格\n",
    "    text = re.sub(r'\\s+', '', text)\n",
    "    # 删除回车\n",
    "    text = re.sub(r'\\n+', '', text)\n",
    "\n",
    "    return text\n",
    "\n",
    "for page in data_pages:\n",
    "    page.page_content = clean_text(page.page_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from generate_qa_pairs import docs_generate_pdf_qa_pairs\n",
    "\n",
    "# qa_pairs = docs_generate_pdf_qa_pairs(pdf_pages=train_pages, num_questions_per_page=1)\n",
    "# qa_pairs.save_json(\"train_dataset.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from generate_qa_pairs import QaPairs\n",
    "qa_pairs = QaPairs.from_json('train_dataset.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 119/119 [00:00<00:00, 1890614.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "问答对成功生成率：0.9243697478991597\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "i = 0\n",
    "for qa_pair in tqdm(qa_pairs.qa_pairs):\n",
    "    if len(qa_pair['query']) > 10:\n",
    "        i += 1\n",
    "print('问答对成功生成率：' + str(i/len(qa_pairs.qa_pairs)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.vectorstores.chroma import Chroma\n",
    "def calculat_recall(k: int, vectordb: Chroma):\n",
    "    i = 0\n",
    "    j = 0\n",
    "    for qa_pair in tqdm(qa_pairs.qa_pairs):\n",
    "        if len(qa_pair['query']) > 10:\n",
    "            query = qa_pair['query']\n",
    "            sim_docs = vectordb.similarity_search(query,k=k)\n",
    "            page_nums = [doc.metadata['page'] for doc in sim_docs]\n",
    "            if qa_pair['page_num'] in page_nums: i += 1\n",
    "            j += 1\n",
    "    return i/j"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "\n",
    "def evaluate_chunk(chunk_size: int):\n",
    "    # 切分文档\n",
    "    text_splitter = RecursiveCharacterTextSplitter(\n",
    "        chunk_size=chunk_size, chunk_overlap=50, separators=['。', '，', ''])\n",
    "\n",
    "    split_docs = text_splitter.split_documents(data_pages)\n",
    "\n",
    "    # 构建向量库\n",
    "    vectordb = Chroma.from_documents(\n",
    "        documents=split_docs,\n",
    "        embedding=embedding,\n",
    "    )\n",
    "    return [calculat_recall(i, vectordb) for i in [1, 3, 5, 10]]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 119/119 [00:07<00:00, 14.95it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 138.61it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 139.66it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 141.72it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 141.97it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 139.12it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 136.24it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 138.65it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 142.85it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 126.86it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 134.26it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 126.83it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 125.34it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 134.81it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 136.61it/s]\n",
      "100%|██████████| 119/119 [00:01<00:00, 118.89it/s]\n"
     ]
    }
   ],
   "source": [
    "chunksize_recall_scores = [evaluate_chunk(i) for i in [200, 300, 400, 500]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "column_names = ['', 'top_1', 'top_3', 'top_5', 'top_10']\n",
    "row_names = ['chunk_200', 'chunk_300', 'chunk_400', 'chunk_500']\n",
    "\n",
    "with open('chunksize_recall.csv', 'w', newline='') as csvfile:\n",
    "    writer = csv.writer(csvfile)\n",
    "    writer.writerow(column_names)\n",
    "    \n",
    "    for i, row in enumerate(chunksize_recall_scores):\n",
    "        writer.writerow([row_names[i]] + row)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm_universe_2.x",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.1.-1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
